# Change working directory 
# Line 50 - Extraction of key DS QOIs
# Line 262- ML ESTIMATION USING DS SAMPLE
# Line 476 - BALANCE PLOTS
# Line 536 - DOUBLY ROBUST MAIN ANALYSES (note lines 320 onwards take approximately 2.5 hours to run)
# Line 2424 - Additional ML algorithms for Appendix A

rm(list = ls())
library(tidyverse)
library(caret)
library(haven)
library(stringr)
library(survey)
library(splines)
library(rlang)
library(lmtest)
library(sandwich)
library(broom)
library(MASS)
library(matrixStats)
library(naniar)
library(VIM)
library(mice)
library(ggpubr)
library(haven)
library(splines)
library(survey)
library(rbw)
library(gbm)
library(xgboost)
library(BART)
library(rlang)
library(rsample)
source("zmisc.R")
set.seed(666)

# Set working directory
#setwd("~/Dropbox (Harvard University)/Gov 2001 Rep Paper/R Scripts")
#setwd("~/Dropbox (Harvard University)/Gov_2001_Rep/R Scripts")

# Default ggplot2 theme for the script: a minimal base theme with the legend
# at the bottom, a centred title, a grey caption and extra space to the right
# of the y-axis title.
mytheme_tweaks <- theme(
  legend.position = "bottom",
  plot.title = element_text(hjust = 0.5),
  plot.caption = element_text(color = "grey30"),
  axis.title.y = element_text(margin = margin(t = 0, r = 10, b = 0, l = 0))
)
mytheme <- theme_minimal(base_size = 13) + mytheme_tweaks
theme_set(mytheme)



################################################
########## (1) EXTRACTION OF DS QOIS ##########
################################################
# Replication data (Stata file) used throughout section (1).
d <- read_dta("Datasets/Downes-Sechser-Appendix-C.dta")

# Set theme (same settings as at the top of the script, so this section can
# be run on its own).
mytheme <- theme_minimal(base_size = 13) +
  theme(
    legend.position = "bottom",
    plot.title = element_text(hjust = 0.5),
    plot.caption = element_text(color = "grey30"),
    axis.title.y = element_text(margin = margin(t = 0, r = 10, b = 0, l = 0))
  )
theme_set(mytheme)


# Control variables, and year dummies for each of the world war years.
# Both `vars` and `variables` hold the same quosure list; later code uses both
# names.
vars <- vars(
  majmaj_ks, majmin_ks, minmaj_ks, capshare_a_ks, contig_ks, swt_dyad,
  tau_lead_a, tau_lead_b, territory_ks, government_ks, policy_ks, other_ks
)
variables <- vars

## Continuous (author's variable notes)
# capshare_a_ks - initiator's share of capabilities
# swt_dyad      - alliance portfolio similarity (each ally's military
#                 capabilities)
# tau_lead_a    - status quo evaluation of initiator: similarity of a given
#                 state's alliance portfolio to that of the most powerful
#                 state in the system
# tau_lead_b    - status quo evaluation of target

## Dummies (0/1)
# majmaj_ks: major power initiator - major power target
# majmin_ks: major power initiator - minor power target
# minmaj_ks: minor power initiator - major power target
# contig_ks: contiguous
# territory_ks: territory
# government_ks: government
# policy_ks: policy
# other_ks: other

year_dummies <- vars(
  dummy1914_ks, dummy1915_ks, dummy1916_ks, dummy1917_ks, dummy1918_ks,
  dummy1939_ks, dummy1940_ks, dummy1941_ks, dummy1942_ks, dummy1943_ks,
  dummy1945_ks
)

# Right-hand sides of the model formulas: the terms joined with `+`,
# with and without the world war year dummies.
ww_terms <- map(c(vars, year_dummies), get_expr)
ww_rhs <- reduce(ww_terms, function(lhs, rhs) expr(!!lhs + !!rhs))

no_ww_terms <- map(vars, get_expr)
no_ww_rhs <- reduce(no_ww_terms, function(lhs, rhs) expr(!!lhs + !!rhs))

# Variables used for logit analysis (bar world war year dummies) - and not
# all democracy dummies - check!
logit_df_with_missings <- dplyr::select(
  d,
  !!!variables, failure_100_ks, democ_a_ks, democ_b_ks, demdem_ks
)

# Note on variable coding (Schultz democracy measure):
# democ_a_ks : democratic initiator
# democ_b_ks : democratic target
# demdem_ks  : both democratic

# Profile of Schultz's original models (Schultz 2001, pp. 146-47)
# Model 1: all crises, World War years included, robust standard errors clustered on crisis
# Model 2: all crises, World War years excluded, robust standard errors clustered on crisis
# Model 3: bilateral crises only, World War years included
# Model 4: bilateral crises only, World War years excluded

# Helper: second row (the initiator-democracy term, i.e. the first regressor
# after the intercept) of the coefficient table with HC1 robust standard
# errors for a fitted model.
robust_treatment_row <- function(model) {
  tidy(coeftest(model, vcov = vcovHC(model, type = "HC1")))[2, ]
}

# Estimation samples shared by the tables below.
d_no_ww <- filter(d, worldwar_ks == 0)
d_bilateral <- filter(d, bilateral_ks == 1)
d_bilateral_no_ww <- filter(d, worldwar_ks == 0, bilateral_ks == 1)

# Table 5: Democracy measure: Schultz 2001 (also Table 4 in "The Illusion of Democratic Credibility" [p. 478])
tab_5_form_ww <- as.formula(expr(failure_100_ks ~ democ_a_ks + democ_b_ks + demdem_ks + !!ww_rhs))
tab_5_form_no_ww <- as.formula(expr(failure_100_ks ~ democ_a_ks + democ_b_ks + demdem_ks + !!no_ww_rhs))

mod1 <- glm(tab_5_form_ww, data = d, family = binomial())
robust_treatment_row(mod1) # democracy treatment and robust SE

mod2 <- glm(tab_5_form_no_ww, data = d_no_ww, family = binomial())
robust_treatment_row(mod2) # democracy treatment and robust SE

mod3 <- glm(tab_5_form_ww, data = d_bilateral, family = binomial())
robust_treatment_row(mod3) # democracy treatment and robust SE

mod4 <- glm(tab_5_form_no_ww, data = d_bilateral_no_ww, family = binomial())
robust_treatment_row(mod4) # democracy treatment and robust SE


# Table 6: Democracy measure: Polity
tab_6_form_ww <- as.formula(expr(failure_100_ks ~ democpolity_a_ks + democpolity_b_ks + demdempolity_ks + !!ww_rhs))
tab_6_form_no_ww <- as.formula(expr(failure_100_ks ~ democpolity_a_ks + democpolity_b_ks + demdempolity_ks + !!no_ww_rhs))

# The printouts below previously re-printed mod1-mod4 (copy-paste); each one
# now reports the model that was just fitted.
mod5 <- glm(tab_6_form_ww, data = d, family = binomial())
robust_treatment_row(mod5) # democracy treatment and robust SE

mod6 <- glm(tab_6_form_no_ww, data = d_no_ww, family = binomial())
robust_treatment_row(mod6) # democracy treatment and robust SE

mod7 <- glm(tab_6_form_ww, data = d_bilateral, family = binomial())
robust_treatment_row(mod7) # democracy treatment and robust SE

mod8 <- glm(tab_6_form_no_ww, data = d_bilateral_no_ww, family = binomial())
robust_treatment_row(mod8) # democracy treatment and robust SE


# Table 7: Democracy measure: Cheibub et al. 2010
# Note that data from Cheibub et al. (2010) extend from 1946-2001, thus obviating the need for differential treatment of the world war years.
# NOTE(review): the formula still includes the world war year dummies via
# ww_rhs - confirm this is intended for a post-1946 sample.
tab_7_form_ww <- as.formula(expr(failure_100_ks ~ democ_cheibub_a + democ_cheibub_b + demdem_cheibub + !!ww_rhs))

mod9 <- glm(tab_7_form_ww, data = d, family = binomial())
robust_treatment_row(mod9) # democracy treatment and robust SE

mod10 <- glm(tab_7_form_ww, data = d_bilateral, family = binomial())
robust_treatment_row(mod10) # democracy treatment and robust SE


# Table 8: Democracy measure: Przeworski et al. 2000
# Note that updated data from Przeworski et al. (2000) extend from 1946-2001, thus obviating the need for differential treatment of the world war years.
# Stored under its own name so the Table 7 formula is no longer overwritten.
tab_8_form_ww <- as.formula(expr(failure_100_ks ~ democ_przeworski_a + democ_przeworski_b + demdem_przeworski + !!ww_rhs))

mod11 <- glm(tab_8_form_ww, data = d, family = binomial())
robust_treatment_row(mod11) # democracy treatment and robust SE

mod12 <- glm(tab_8_form_ww, data = d_bilateral, family = binomial())
robust_treatment_row(mod12) # democracy treatment and robust SE


# Table 9: Democracy measure: Boix and Rosato 2001
tab_9_form_ww <- as.formula(expr(failure_100_ks ~ democ_boix_a + democ_boix_b + demdem_boix + !!ww_rhs))
tab_9_form_no_ww <- as.formula(expr(failure_100_ks ~ democ_boix_a + democ_boix_b + demdem_boix + !!no_ww_rhs))

mod13 <- glm(tab_9_form_ww, data = d, family = binomial())
robust_treatment_row(mod13) # democracy treatment and robust SE

mod14 <- glm(tab_9_form_no_ww, data = d_no_ww, family = binomial())
robust_treatment_row(mod14) # democracy treatment and robust SE

mod15 <- glm(tab_9_form_ww, data = d_bilateral, family = binomial())
robust_treatment_row(mod15) # democracy treatment and robust SE

mod16 <- glm(tab_9_form_no_ww, data = d_bilateral_no_ww, family = binomial())
robust_treatment_row(mod16) # democracy treatment and robust SE

# Tibble output for models 1-16: for each fitted model modN, rows 2-4 of the
# HC1-robust coefficient table (the three democracy terms that follow the
# intercept in every formula), stacked into one long tibble.
holder <- map(1:16, function(mod) {
  fit <- get(paste0("mod", mod))
  # Compute the robust coefficient table once per model and reuse it for both
  # the estimate and the standard error.
  dem_rows <- tidy(coeftest(fit, vcov = vcovHC(fit, type = "HC1")))[2:4, ]

  tibble("Democracy Indictor" = c("Initiator", "Target", "Both Democratic"),
         "Model" = paste0("Model_", mod),
         "Estimate" = dem_rows$estimate,
         "SE" = dem_rows$std.error)
})

df <- bind_rows(holder)


######################################################
######## Plotting predicted probabilities #########
######################################################

# Complete cases only: rows with a missing value in any selected column are
# dropped before simulating predicted probabilities.
logit_df <- logit_df_with_missings %>% drop_na()

# Predicted ATE when variables held at mean + median
# Inverse-logit (logistic) transform: maps log-odds `x` (vector or matrix) to
# probabilities, keeping the shape of `x`.
# Written as 1 / (1 + exp(-x)) so that large |x| saturates at 0 or 1; the
# previous exp(x) / (1 + exp(x)) form returned NaN (Inf / Inf) once exp(x)
# overflowed.
invlogit <- function(x) {
  1 / (1 + exp(-x))
}

# Simulated first difference in the predicted probability from `mod2` when
# democ_a_ks moves from 0 to 1 and every other covariate in `logit_df` is held
# at a summary value.
#
# other_covariates: name of the summary function as a string ("mean" or
#   "median"); it is applied to each covariate column and to the simulated
#   differences to obtain the point estimate.
# Returns a one-row data.frame with point_est, p_min / p_max (2.5% and 97.5%
#   quantiles of the simulated differences) and other_cov (the function name).
#
# Reads the globals `mod2`, `logit_df` and `invlogit`.
unconditional_first_diff <- function(other_covariates) {

  summary_fun <- match.fun(other_covariates)
  coef_names <- names(coef(mod2))

  # 5000 draws from the estimated sampling distribution of the coefficients.
  draws <- mvtnorm::rmvnorm(5000, mean = coef(mod2), sigma = vcov(mod2))
  colnames(draws) <- coef_names

  # Summary value of every covariate except the outcome and the treatment.
  covariates <- dplyr::select(logit_df, -c(failure_100_ks, democ_a_ks))
  held_at <- apply(as.matrix(covariates), MARGIN = 2, summary_fun, na.rm = TRUE)

  # Covariate profile for a given treatment value, ordered by coefficient
  # NAME. The previous version sorted both sides alphabetically and relied on
  # "(Intercept)" sorting first, which depends on the locale's collation.
  profile_for <- function(treatment) {
    profile <- c("(Intercept)" = 1, held_at, democ_a_ks = treatment)
    stopifnot(all(coef_names %in% names(profile)))
    profile[coef_names]
  }

  ate <- invlogit(drop(draws %*% profile_for(1))) -
    invlogit(drop(draws %*% profile_for(0)))

  data.frame(point_est = summary_fun(ate),
             p_min = unname(quantile(ate, 0.025)),
             p_max = unname(quantile(ate, 0.975)),
             other_cov = as.character(other_covariates))
}

mean_median <- rbind(unconditional_first_diff("mean"), unconditional_first_diff("median"))

# Predicted ATE when variables held at marginal values
# Draw `ndraws` coefficient vectors from a multivariate normal with mean
# `coeff` and covariance `vcov`; returns the matrix produced by
# mvtnorm::rmvnorm().
sim_coeff <- function(coeff, vcov, ndraws = 5000) {
  mvtnorm::rmvnorm(n = ndraws, mean = coeff, sigma = vcov)
}
# Simulated coefficient draws from mod2, one column per coefficient.
coef_names <- names(coef(mod2))
sims <- sim_coeff(coef(mod2), vcov(mod2))
colnames(sims) <- coef_names

# Observed covariates with the treatment set to 1 / 0 for every row.
data_sim_dem1 <- dplyr::select(logit_df, -c(failure_100_ks, democ_a_ks)) %>% mutate(democ_a_ks = 1)
data_sim_dem0 <- dplyr::select(logit_df, -c(failure_100_ks, democ_a_ks)) %>% mutate(democ_a_ks = 0)

# Design matrix whose columns are ordered by coefficient NAME. Previously the
# draws were sorted alphabetically while the data columns kept their original
# order, so coefficients were multiplied with the wrong covariates.
# NOTE(review): logit_df is built from the full data while mod2 is fitted on
# worldwar_ks == 0 only - confirm the averaging sample is the intended one.
aligned_design <- function(data) {
  design <- cbind("(Intercept)" = 1, as.matrix(data))
  stopifnot(all(coef_names %in% colnames(design)))
  design[, coef_names, drop = FALSE]
}

# Average predicted probability over observations, separately for each draw.
p1 <- colMeans(invlogit(aligned_design(data_sim_dem1) %*% t(sims)), na.rm = TRUE)
p0 <- colMeans(invlogit(aligned_design(data_sim_dem0) %*% t(sims)), na.rm = TRUE)
ate <- p1 - p0
marginal <- data.frame(point_est = mean(ate), p_min = quantile(ate, 0.025), p_max = quantile(ate, 0.975),  other_cov =  "marginal")

mean_median_marginal <- rbind(mean_median, marginal)
saveRDS(mean_median_marginal, file = "mean_median_marginal.RDS")

# Point estimates and simulated 95% intervals for the three ways of fixing the
# covariates.
# - `limits` pins the category order so the positional labels match the data;
#   by default a discrete axis orders character values alphabetically
#   ("marginal", "mean", "median"), which mislabelled all three estimates.
# - theme_minimal() is a complete theme, so it must come BEFORE the theme()
#   tweak or it overrides it.
plot <- ggplot(mean_median_marginal, aes(x = other_cov, y = point_est)) +
  geom_point(size = 3.35, color = "grey50") +
  geom_errorbar(aes(ymin = p_min, ymax = p_max,
                    width = 0.5), show.legend = FALSE, color = "grey50") +
  scale_x_discrete(limits = c("mean", "median", "marginal"),
                   labels = c("Mean", "Median", "Marginals")) +
  xlab("Level of Adjusted Covariates") +
  scale_y_continuous(name = "ATE Estimate and 95% \nSimulated CIs") +
  theme_minimal(base_size = 16) +
  theme(axis.title.x = element_text(vjust = -0.8))
plot
# Pass the plot explicitly instead of relying on the last plot displayed.
ggsave(path = "Figures", filename = "original_estimates_probs.jpg", plot = plot, width = 6, height = 4)





################################################
###### (2) ML ESTIMATION USING DS SAMPLE #######
################################################


rm(list = ls())
library(tidyverse)
library(caret)
library(haven)
library(stringr)
library(survey)
library(splines)
library(rlang)
library(lmtest)
library(sandwich)
library(broom)
library(MASS)
library(matrixStats)
library(naniar)
library(VIM)
library(mice)
library(ggpubr)
library(haven)
library(splines)
library(survey)
library(rbw)
library(gbm)
library(xgboost)
library(BART)
library(rlang)
library(rsample)
source("zmisc.R")
set.seed(666)

# Replication data for the ML re-analysis.
data <- read_dta("Datasets/Downes-Sechser-Appendix-C.dta")

# Control variables and year dummies for each of the world war years
variables <- vars(
  majmaj_ks, majmin_ks, minmaj_ks, capshare_a_ks, contig_ks, swt_dyad,
  tau_lead_a, tau_lead_b, territory_ks, government_ks, policy_ks, other_ks
)

# Schultz democracy indicators: initiator, target, both.
schultz_dems <- vars(democ_a_ks, democ_b_ks, demdem_ks)

year_dummies <- vars(
  dummy1914_ks, dummy1915_ks, dummy1916_ks, dummy1917_ks, dummy1918_ks,
  dummy1939_ks, dummy1940_ks, dummy1941_ks, dummy1942_ks, dummy1943_ks,
  dummy1945_ks
)

# Formula right-hand sides: the terms of each list joined with `+`.
ww_rhs <- reduce(
  map(c(variables, year_dummies), get_expr),
  function(lhs, rhs) expr(!!lhs + !!rhs)
)

no_ww_rhs <- reduce(
  map(variables, get_expr),
  function(lhs, rhs) expr(!!lhs + !!rhs)
)

schultz_dems_rhs <- reduce(
  map(schultz_dems, get_expr),
  function(lhs, rhs) expr(!!lhs + !!rhs)
)

# Schultz democracy measure attempt: no world wars.
# Complete cases of the non-world-war crises, with treatment (democ_a_ks) and
# outcome (failure_100_ks) stored both as doubles and as "no"/"yes" factors.
schultz_df <- data %>%
  filter(worldwar_ks == 0) %>%
  dplyr::select(!!!variables, !!!schultz_dems, failure_100_ks) %>%
  drop_na() %>%
  mutate(
    treat_num = as.double(democ_a_ks),
    treat = factor(democ_a_ks, labels = c("no", "yes")),
    outcome = factor(failure_100_ks, labels = c("no", "yes")),
    outcome_num = as.double(failure_100_ks)
  )
d <- schultz_df

names(schultz_df)
nrow(schultz_df) #198 non world war observations

# Factual copy, plus a counterfactual copy with the numeric treatment flipped.
factual <- d
counter <- d
schultz_df <- d
counter$treat_num <- 1 - counter$treat_num


# Exploration of potential model dependency
#################################################################
############# ATE via debiased machine learning ###########
#################################################################
## NB using random forest for propensity score and outcome models

# excluded demdem_ks as i dont think should include varibae expanded interaction terms
# Propensity-score formula: treatment factor on the controls plus target
# democracy.
a_formula <- as.formula(expr(treat ~ !!no_ww_rhs + democ_b_ks))

# Outcome formula: same controls plus the treatment indicator democ_a_ks.
y_formula <- as.formula(expr(outcome ~ !!no_ww_rhs + democ_b_ks + democ_a_ks))


# Cross-fitting: split the sample into K folds; each fold is used once as the
# estimation ("main") sample while the nuisance models are trained on the
# remaining ("aux") observations.
K <- 5
whole_folds <- createFolds(d$outcome, k = K)
predicted_list <- vector(mode = "list", K)

###############################################
#  Use each fold as an estimation sample once
################################################
for (k in seq_len(K)) {

  cat("fold ", k, "\n")
  aux <- d[-whole_folds[[k]], ]
  main <- d[whole_folds[[k]], ]

  ## Data matrices for train() and predict()
  # Auxiliary sample: design matrix of the outcome model (intercept dropped)
  aux_X <- model.matrix(y_formula, data = aux)[, -1]

  # Counterfactual copies of the main sample. The outcome model's treatment
  # column is democ_a_ks, so THAT column must be overwritten. Previously only
  # `treat` / `treatyes` were set, neither of which enters y_formula, so the
  # "treated" and "untreated" predictions were identical.
  main0 <- main1 <- main
  main0$democ_a_ks <- 0
  main1$democ_a_ks <- 1

  main_X_a <- model.matrix(a_formula, data = main)[, -1]
  main_X_a0 <- model.matrix(y_formula, data = main0)[, -1]
  main_X_a1 <- model.matrix(y_formula, data = main1)[, -1]

  # Inner CV folds for tuning. trainControl(index = ) expects the TRAINING
  # rows of each resample, so request them explicitly; the createFolds()
  # default returns the held-out rows instead.
  myFolds <- createFolds(aux$outcome, k = 5, returnTrain = TRUE)

  ###############################
  # Propensity score models
  ###############################
  aControl <- trainControl(
    summaryFunction = mnLogLoss,
    number = 5,
    method = "cv",
    classProbs = TRUE,
    savePredictions = TRUE,
    index = myFolds
  )

  # Random forest
  cat(" ps: rf", "\n")
  rf_ps <- train(
    a_formula,
    data = aux,
    method = "ranger",
    trControl = aControl,
    tuneLength = 3
  )

  # Predict once and read both class probabilities from the same matrix.
  rf_ps_hat <- predict(rf_ps$finalModel, data = main, type = "response")$predictions
  main$rf_ps0 <- rf_ps_hat[, "no"]
  main$rf_ps1 <- rf_ps_hat[, "yes"]

  #  Support Vector Machine
  cat(" ps: svm", "\n")
  svm_ps <- train(
    a_formula,
    data = aux,
    method = "svmRadial",
    trControl = aControl,
    prob.model = TRUE,
    tuneGrid = expand.grid(sigma = seq(0.1, 1, 0.1), C = 1))

  svm_ps_hat <- kernlab::predict(svm_ps$finalModel, main_X_a, type = "probabilities")
  main$svm_ps1 <- svm_ps_hat[, "yes"]
  main$svm_ps0 <- svm_ps_hat[, "no"]


  ###############################
  # Outcome models
  ###############################
  yControl <- trainControl(method = "cv",
                           index = myFolds,
                           summaryFunction = mnLogLoss,
                           classProbs = TRUE)

  #  Random Forest
  cat(" outcome: rf", "\n")
  rf_y <- train(aux_X, aux$outcome, method = "ranger",
                trControl = yControl, tuneLength = 3)

  # P(outcome = "yes") with the treatment set to 0 and to 1.
  main$rf_y0 <- predict(rf_y$finalModel, data = main0, type = "response")$predictions[, "yes"]
  main$rf_y1 <- predict(rf_y$finalModel, data = main1, type = "response")$predictions[, "yes"]


  # Support Vector Machine
  cat(" outcome: svm", "\n")
  svm_y <- train(
    y_formula,
    data = aux,
    method = "svmRadial",
    trControl = yControl,
    tuneGrid = expand.grid(sigma = c(0.005, 0.01), C = c(0.5, 1, 5, 10)))

  main$svm_y0 <- kernlab::predict(svm_y$finalModel, main_X_a0, type = "probabilities")[, "yes"]
  main$svm_y1 <- kernlab::predict(svm_y$finalModel, main_X_a1, type = "probabilities")[, "yes"]


  predicted_list[[k]] <- main

}

# Stack the out-of-fold predictions back into one data frame.
predicted_df <- bind_rows(predicted_list)

### CONSTRUCT PSEUDO OUTCOMES
# For every (propensity model, outcome model) pair, add three columns to
# predicted_df:
#   <ps>_<y>_y0 : outcome prediction under control plus the inverse-propensity
#                 weighted residual of the untreated observations
#   <ps>_<y>_y1 : the same for the treated
#   <ps>_<y>_TE : their difference
# Propensities are passed through trim() (expected to come from zmisc.R)
# before dividing.
methods <- c("rf", "svm")

for (psmod in methods) {
  for (ymod in methods) {

    pair <- str_c(psmod, "_", ymod)
    y0_hat <- predicted_df[[str_c(ymod, "_y0")]]
    y1_hat <- predicted_df[[str_c(ymod, "_y1")]]
    ps0_hat <- predicted_df[[str_c(psmod, "_ps0")]]
    ps1_hat <- predicted_df[[str_c(psmod, "_ps1")]]

    predicted_df[[str_c(pair, "_y0")]] <-
      y0_hat + (1 - predicted_df$treat_num) / trim(ps0_hat) * (predicted_df$outcome_num - y0_hat)
    predicted_df[[str_c(pair, "_y1")]] <-
      y1_hat + predicted_df$treat_num / trim(ps1_hat) * (predicted_df$outcome_num - y1_hat)
    predicted_df[[str_c(pair, "_TE")]] <-
      predicted_df[[str_c(pair, "_y1")]] - predicted_df[[str_c(pair, "_y0")]]

  }
}


# ATE per model pair: intercept-only survey regression of each pseudo-outcome
# column <ps>_<y>_TE. The intercept is the ATE estimate and the square root
# of its variance is the standard error.
holder_ATE <- list()

design <- suppressWarnings(svydesign(ids = ~ 1, data = predicted_df))

for (psmod in methods) {
  for (ymod in methods) {
    pair <- paste0(psmod, "_", ymod)
    te_column <- sym(paste0(pair, "_TE"))

    fit <- eval(expr(svyglm(!!te_column ~ 1, design = design)))
    # Keep the fitted model available under <ps>_<y>_modTE, as before.
    assign(paste0(pair, "_modTE"), fit)

    holder_ATE[[pair]] <- tibble("educ2" = "overall_ate",
                                 "model" = pair,
                                 "est" = unname(coef(fit)),
                                 "se" = sqrt(unname(diag(vcov(fit)))))
  }
}

# Plot results
final_output <- bind_rows(holder_ATE)

# ATE: point estimates with +/- 1.96 SE bars for each combination of
# propensity-score and outcome learner.
ate <- final_output %>%
  mutate(model = factor(model, levels = c("rf_rf", "rf_svm", "svm_rf", "svm_svm"),
                        labels = c("RF \nand RF", "RF \nand SVM", "SVM \nand RF", "SVM \nand SVM"))) %>%
  ggplot(aes(x = model, y = est)) +
  geom_point(size = 3.35, color = "grey50") +
  geom_errorbar(aes(ymin = est - 1.96 * se, ymax = est + 1.96 * se,
                    width = 0.5), show.legend = FALSE, color = "grey50") +
  scale_x_discrete(name = "\nML Algorithms used for Propensity\n Score and Outcome Models ") +
  scale_y_continuous(name = "Estimated ATE of Regime Type on \nConflict Success") +
  theme_minimal(base_size = 16) +
  theme(legend.position = "bottom",
        plot.title = element_text(hjust = 0.5),
        plot.caption = element_text(color = "grey30"),
        axis.title.y = element_text(margin = margin(t = 0, r = 10, b = 0, l = 0)),
        axis.text.x = element_text(vjust = 0.6))
ate
# Spell out `filename` (the old `file =` only worked through partial argument
# matching) and pass the plot explicitly rather than relying on the last plot.
ggsave(path = "Figures", filename = "l_original_data.jpg", plot = ate, width = 6.7, height = 5)


################################
###### (3) BALANCE PLOTS #######
################################


#full_sample_dropped <- readRDS("full_sample_dropped.RDS")
# Fifth imputed case-control data set, used for the balance plots.
full_sample_case_control_dropped <- full_sample_case_control_imputed <-
  readRDS("full_sample_case_control_imputed.RDS")[[5]]

# Plotting data with display names for the covariates and a labelled regime
# factor. The pipeline previously started from `full_sample_dropped`, which is
# never created (its readRDS() call is commented out), so this section failed
# in a fresh session; it now uses the data set loaded just above.
# NOTE(review): confirm the imputed case-control sample is the intended source.
plotting <- full_sample_case_control_dropped %>%
  mutate(Mediator = factor(mediator)) %>%
  mutate("Military Quintile" = milex_quintile,
         "Regional Power Rank" = reg_power_rank,
         "Female Suffrage" = female_suffrage,
         "Average Education Level" = e_peaveduc,
         "National Capabilities" = cinc_a,
         "Conscription" = recruit,
         polity_a = factor(polity_a,
                           levels = c(0, 1),
                           labels = c("Autocracy", "Democracy")))

plot_names <- c("Military Quintile", "Regional Power Rank", "Female Suffrage", "Average Education Level", "National Capabilities")

# One density plot per covariate, coloured by regime type with the line type
# given by the mediator.
plots <- map(plot_names, function(name) {
  variable <- sym(name)
  ggplot(plotting, aes(!!variable, color = polity_a, linetype = Mediator)) +
    geom_density() +
    theme_minimal() +
    ylab("Density") +
    theme(plot.title = element_text(hjust = 0.5, size = 9)) +
    labs(color = "Regime Type") +
    theme(plot.margin = unit(c(.5, .5, .5, .5), "cm"))
})

names(plots) <- seq_along(plots)

plot1 <- plots[[1]]
plot2 <- plots[[2]]
plot3 <- plots[[3]]
plot4 <- plots[[4]]
plot5 <- plots[[5]]

# conscription as binary
# Relabelled only after the density plots are built, so the new level names
# apply to the facet strips of the bar chart below.
levels(plotting$Mediator) <- c("Mediator = 0", "Mediator = 1")
conscription_plot <- plotting %>% mutate(Conscription = factor(recruit)) %>%
  ggplot(., aes(polity_a, fill = Conscription)) +
  geom_bar(position = "fill", aes(alpha = 0.4)) +
  theme_minimal() +
  facet_wrap(~ Mediator) +
  ylab('Count') +
  xlab("") +
  scale_alpha(guide = 'none')


plot_all <- ggarrange(plot1, plot2, plot3, plot4, plot5, conscription_plot, nrow = 3, ncol = 2, common.legend = TRUE, legend = "bottom")

# Save the arranged figure explicitly instead of whatever plot was built last.
ggplot2::ggsave("Figures/densityplot.png", plot = plot_all, width = 7, height = 8)



##############################################
######## (4) MULTIPLY ROBUST ANALYSES #########
##############################################


rm(list = ls())
library(tidyverse)
library(caret)
library(haven)
library(stringr)
library(survey)
library(splines)
library(rlang)
library(lmtest)
library(sandwich)
library(broom)
library(MASS)
library(matrixStats)
library(naniar)
library(VIM)
library(mice)
library(ggpubr)
library(haven)
library(splines)
library(survey)
library(rbw)
library(gbm)
library(xgboost)
library(BART)
library(rlang)
library(rsample)
source("zmisc.R")
set.seed(02138)


# create treatment and outcome formulas for fitting GBM models
# Covariate sets used to build the treatment, mediator and outcome formulas.
# NOTE(review): the original comment on x_vars said contig_ks was omitted
# because of NAs (raising the sample from 500 to 800 observations), yet it is
# included here - confirm which is intended.
x_vars <- vars(major_a, alliance_gg, cinc_a, cinc_b, contig_ks)

x_vars_c <- vars(major_a, alliance_gg, cinc_a, cinc_b) # continuous X (given polynomial terms later)
x_vars_b <- vars(contig_ks)                            # binary X
dem_variables <- vars(polity_a, polity_b) # don't create interactive variable for ml
z_vars_b <- vars(recruit)                              # binary Z
z_vars_c <- vars(female_suffrage, milex_quintile, reg_power_rank, e_peaveduc) # TO CONSIDER: put cinc_a and major_a in Z

# Terms of each set joined with `+`, ready to splice into a formula.
x_rhs <- reduce(map(x_vars, get_expr), function(lhs, rhs) expr(!!lhs + !!rhs))
z_rhs <- reduce(map(c(z_vars_b, z_vars_c), get_expr), function(lhs, rhs) expr(!!lhs + !!rhs))
dem_variables_rhs <- reduce(map(dem_variables, get_expr), function(lhs, rhs) expr(!!lhs + !!rhs))

# formulas: mediator model, outcome model, treatment (propensity) model
m_form <- as.formula(expr(mediator_factor ~ !!x_rhs + !!z_rhs + treatment_factor + polity_b))
y_form <- as.formula(expr(compliance ~ !!x_rhs + !!z_rhs + treatment_factor + polity_b))
a_form <- as.formula(expr(treatment_factor ~ !!x_rhs))

### Estimators with multiple imputation
# load data
# WITH MIDS CONTROL GROUP
# imputation_all  <- readRDS("imputed_full.RDS") %>% 
#   map(., ~ mutate(.x, mediator = mediator,
#                   mediator_factor = factor(mediator, labels = c("no", "yes")),
#                   outcome = factor(if_else(compliance %in% c(2,1), 1, 0),  labels = c("yes", "no"))))

# WITH CASE CONTROL GROUP
# Each imputed data set gets factor versions of the mediator and treatment
# and a binary outcome factor.
# NOTE(review): the outcome labels are c("yes", "no"), so compliance in
# {1, 2} (coded 1) is labelled "no" - confirm this direction is intended.
add_model_factors <- function(imputed) {
  mutate(
    imputed,
    mediator = mediator, # re-assigns the column to itself
    mediator_factor = factor(mediator, labels = c("no", "yes")),
    treatment_factor = factor(polity_a, labels = c("no", "yes")),
    outcome = factor(if_else(compliance %in% c(2,1), 1, 0), labels = c("yes", "no"))
  )
}
imputation_all <- map(readRDS("full_sample_case_control_imputed.RDS"), add_model_factors)


# full_sample_case_control_imputed <- readRDS("full_sample_case_control_imputed.RDS")[[1]]


# Number of imputed data sets, and one result slot per data set.
I <- length(imputation_all)

out <- vector(mode = "list", I)

# TO DO
for(i in 1:I){
  
  cat("imputed sample ", i, "\n")
  
  d <- imputation_all[[i]]
  
  k <- 1
  K <- 5
  whole_folds <- createFolds(d$mediator, k = K) 
  predicted_list <- vector(mode = "list", K)
  
  for (k in 1:K) {
    
    cat("fold ", k, "\n")
    aux <- d[-whole_folds[[k]], ]
    main <- d[whole_folds[[k]], ]
    
    ## Data matrices for train() and predict()
    
    # Auxiliary sample models
    aux_X <-  model.matrix(y_form, data = aux)[, -1] #outcome model only 
    aux_X_outcome <- dplyr::select(filter(aux, mediator == 1), outcome)
    aux_X_mediator <-  model.matrix(m_form, data = aux)[, -1] 
    
    # Counterfactual main sample models
    #outcome models prediction dataset
    main0 <- main1 <- main %>% dplyr::select(!!!x_vars, !!!z_vars_b, !!!z_vars_c, polity_b, mediator_factor, compliance) %>% rename(mediator_factoryes = mediator_factor)
    main_X_a <-  model.matrix(a_form, data = main)[, -1]
    
    main0$treatment_factoryes  <- 0   ## check name of variable required for RF prediction 
    main1$treatment_factoryes  <- 1
    
    
    #mediator models prediction dataset
    main0_med_mod <- main1_med_mod <- main %>% dplyr::select(!!!x_vars, !!!z_vars_b, !!!z_vars_c, polity_b)
    
    main0_med_mod$treatment_factoryes  <- 0   ## check name of variable required for RF prediction 
    main1_med_mod$treatment_factoryes  <- 1
    
    main1gbmoutcome <- main0gbmoutcome <- main0_med_mod
    main0gbmoutcome$treatment_factor <- 0
    main1gbmoutcome$treatment_factor <- 1
    
    y_form_yes <- as.formula(expr(~ !!x_rhs + !!z_rhs + treatment_factoryes + polity_b))
    main0_X <- model.matrix(y_form_yes, data = main0_med_mod)[, -1]
    main1_X <- model.matrix(y_form_yes, data = main1_med_mod)[, -1]
    
    #mediator models auxiliary matrix
    m_form_yes <- as.formula(expr( ~ !!x_rhs + !!z_rhs + treatment_factoryes + polity_b))
    aux_X_mps0 <- model.matrix(m_form_yes, data = main0_med_mod)[, -1] 
    aux_X_mps1 <- model.matrix(m_form_yes, data = main1_med_mod)[, -1] 
    
    
    ##### Lasso, ridge and e-net prep #####
    ##########################
    
    # X numeric
    numeric_confounders <- map(x_vars_c, get_expr)
    nc_vector_x <- vector(mode = "list")
    
    for (var in numeric_confounders) {
      nc_vector_x[[""]]  <- parse_expr(paste0("poly(", var,", 2, raw = TRUE)"))
    }
    
    # Z numeric
    numeric_confounders <- map(z_vars_c, get_expr)
    nc_vector_z <- vector(mode = "list")
    
    for (var in numeric_confounders) {
      nc_vector_z[[""]]  <- parse_expr(paste0("poly(", var,", 2, raw = TRUE)"))
    }
    
    dem_variables <- vars(polity_a, polity_b) # don't create interactive varibale for ml
    
    x_rhs_sq <- unlist(append(map(x_vars_b, get_expr), nc_vector_x)) %>%
      reduce(~ expr(!!.x + !!.y))
    
    z_rhs_sq <- unlist(append(map(z_vars_b, get_expr), nc_vector_z)) %>%
      reduce(~ expr(!!.x + !!.y))
    
    # A formulas and design matrices
    a_form_sq <- as.formula(expr(treatment_factor ~ (!!x_rhs_sq)^2))
    
    # Mediator model formulas and design matrices
    m_form_sq <- as.formula(expr(mediator_factor ~ (!!z_rhs_sq + !!x_rhs_sq + treatment_factor + polity_b)^2))
    
    m_form_sq_yes <- as.formula(expr( ~ (!!z_rhs_sq + !!x_rhs_sq + treatment_factoryes + polity_b)^2))
    
    main0_X_sq <- model.matrix(m_form_sq_yes, data = main0_med_mod)[, -1]
    main1_X_sq <- model.matrix(m_form_sq_yes, data = main1_med_mod)[, -1]
    
    maina_X_sq <- model.matrix(a_form_sq, data = main)[, -1]
    
    
    # Outcome formulas and design matrices
    y_form_sq <- as.formula(expr(compliance ~ (!!x_rhs_sq + !!z_rhs_sq + treatment_factor + polity_b)^2))
    y_form_sq_yes <- as.formula(expr( ~ (!!x_rhs_sq + !!z_rhs_sq + treatment_factoryes + polity_b)^2))
    
    aux_X_outcome_sq <- model.matrix(y_form_sq, data = filter(aux, mediator == 1))[, -1]
    
    main0_X_outcome_sq <- model.matrix(y_form_sq_yes, data = main0_med_mod)[, -1]
    main1_X_outcome_sq <- model.matrix(y_form_sq_yes, data = main1_med_mod)[, -1]
    
    
    ###########
    myFolds <- createFolds(aux$mediator, k = 5)
    
    ###############################
    # Propensity score models for M- probability p of receiving mediator and 1-p at a and a'
    ###############################
    aControl <- trainControl(
      summaryFunction = mnLogLoss, 
      number = 5,
      method = "cv",
      classProbs = TRUE,
      savePredictions = TRUE,
      index = myFolds
    )
    
    # Random forest
    cat(" ps_m: rf", "\n")
    rf_ps_m <- train(
      m_form,
      data = aux,
      method = "ranger", 
      trControl = aControl, 
      tuneLength = 3
    )
    
    main$rf_mps0 <- predict(rf_ps_m$finalModel, data = main0_med_mod, type = "response")$predictions[, "yes"] #mps0 refers to value of M under A=0
    main$rf_mps1 <- predict(rf_ps_m$finalModel, data = main1_med_mod, type = "response")$predictions[, "yes"]
    
    
    # Support Vector Machine
    svm_ps_m <- train(
      m_form,
      data = aux,
      method = "svmRadial",
      trControl = aControl, 
      prob.model = TRUE,
      tuneGrid = expand.grid(sigma = seq(0.1, 1, 0.1), C = 1))
    
    main$svm_mps1  <- kernlab::predict(svm_ps_m$finalModel, main1_X, type = "probabilities")[, "yes"]
    main$svm_mps0 <- kernlab::predict(svm_ps_m$finalModel, main0_X, type = "probabilities")[, "yes"]
    
    # (1) Lasso Regression
    cat(" ps: lasso", "\n")
    lasso_ps <- train(
      m_form_sq, 
      data = aux, 
      method = "glmnet",
      trControl = aControl, 
      tuneGrid = expand.grid(
        alpha = 1, 
        lambda = 10^seq(1, -4, length=100))
    )
    
    main$lasso_mps1 <- predict(lasso_ps$finalModel, newx = main1_X_sq, s = lasso_ps$finalModel$lambdaOpt, type = "response") 
    main$lasso_mps0 <- predict(lasso_ps$finalModel, newx = main0_X_sq, s = lasso_ps$finalModel$lambdaOpt, type = "response") 
    
    # (2) Ridge Regression
    cat(" ps: ridge", "\n")
    ridge_ps <- train(
      m_form_sq, 
      data = aux, 
      method = "glmnet",
      trControl = aControl, 
      tuneGrid = expand.grid(
        alpha = 0, 
        lambda = 10^seq(1, -4, length=100))
    )
    
    main$ridge_mps1 <- predict(ridge_ps$finalModel, newx = main1_X_sq, s = ridge_ps$finalModel$lambdaOpt, type = "response")
    main$ridge_mps0 <-predict(ridge_ps$finalModel, newx = main0_X_sq, s = ridge_ps$finalModel$lambdaOpt, type = "response")
    
    # (3) Elastic Net
    cat(" ps: elastic_net", "\n")
    elastic_net_ps <- train(
      m_form_sq,
      data = aux,
      method = "glmnet",
      trControl = aControl,
      tuneGrid = expand.grid(
        alpha = 0.5,
        lambda = 10^seq(1, -4, length=100))
    )
    
    main$elastic_net_mps1 <- predict(elastic_net_ps$finalModel, newx = main1_X_sq, s = elastic_net_ps$finalModel$lambdaOpt, type = "response")
    main$elastic_net_mps0 <- predict(elastic_net_ps$finalModel, newx = main0_X_sq, s = elastic_net_ps$finalModel$lambdaOpt, type = "response")
    
    
    # Gradient Boosting Machine
    cat(" outcome: gbm", "\n")
    invisible(capture.output(
      gbm_ps_m <- train(
        aux_X_mediator,
        aux$mediator_factor,
        method = "gbm",
        distribution = "bernoulli",
        trControl = aControl,
        tuneLength = 10,
        
      )
    ))
    
    main$gbm_mps0  <- 1 - predict(gbm_ps_m$finalModel, newdata = main0_X, n.trees = gbm_ps_m$finalModel$n.trees, type = "response")
    main$gbm_mps1  <- 1 - predict(gbm_ps_m$finalModel, newdata = main1_X, n.trees = gbm_ps_m$finalModel$n.trees, type = "response")
    
    
    ###############################
    # Propensity score models for A: probability of treatment A
    ###############################
    
    # trainControl(index = ) expects, for every resample, the rows used for
    # TRAINING. createFolds() returns the held-out rows unless
    # returnTrain = TRUE, so without it each candidate model was tuned on one
    # fifth of the data and scored on the remaining four fifths.
    myFolds <- createFolds(aux$treatment_factor, k = 5, returnTrain = TRUE)
    
    # Resampling set-up shared by every treatment propensity model below:
    # multinomial log-loss as the tuning metric, class probabilities enabled.
    aControl <- trainControl(
      summaryFunction = mnLogLoss, 
      number = 5,
      method = "cv",
      classProbs = TRUE,
      savePredictions = TRUE,
      index = myFolds
    )
    
    # Random forest for the treatment propensity score
    cat(" ps_a: rf", "\n")
    rf_ps_a <- train(a_form, data = aux, method = "ranger",
                     trControl = aControl, tuneLength = 3)
    
    # Score the held-out fold once and split the class-probability matrix
    # into its "no" and "yes" columns.
    rf_a_prob <- predict(rf_ps_a$finalModel, data = main, type = "response")$predictions
    main$rf_aps0 <- rf_a_prob[, "no"]
    main$rf_aps1 <- rf_a_prob[, "yes"]
    
    
    # Support Vector Machine (radial kernel) for the treatment propensity score;
    # sigma is tuned over 0.1, 0.2, ..., 1 with the cost fixed at 1.
    svm_ps_a <- train(
      a_form,
      data = aux,
      method = "svmRadial",
      trControl = aControl,
      prob.model = TRUE,
      tuneGrid = expand.grid(sigma = seq(0.1, 1, by = 0.1), C = 1)
    )
    
    # One probability prediction, then pick out each class column.
    svm_a_prob <- kernlab::predict(svm_ps_a$finalModel, main_X_a, type = "probabilities")
    main$svm_aps1 <- svm_a_prob[, "yes"]
    main$svm_aps0 <- svm_a_prob[, "no"]
    
    # Gradient Boosting Machine for the treatment propensity score
    # (warnings raised while tuning are suppressed).
    gbm_ps_a <- suppressWarnings(
      train(a_form, data = aux, method = "gbm", distribution = "bernoulli",
            trControl = aControl, tuneLength = 10)
    )
    
    # The raw gbm prediction is stored as the a = 0 propensity; the a = 1
    # propensity is its complement.
    main$gbm_aps0 <- gbm_ps_a$finalModel %>%
      predict(newdata = main_X_a, n.trees = gbm_ps_a$finalModel$n.trees, type = "response")
    main$gbm_aps1 <- 1 - main$gbm_aps0
    
    
    # Penalised logistic regressions for the treatment propensity score.
    # All three share a log-spaced lambda grid (10 down to 1e-4) and differ
    # only in the mixing parameter alpha. Note that lasso_ps, ridge_ps and
    # elastic_net_ps are overwritten here.
    
    # (1) Lasso regression (alpha = 1)
    cat(" ps: lasso", "\n")
    lasso_ps <- train(
      a_form_sq,
      data = aux,
      method = "glmnet",
      trControl = aControl,
      tuneGrid = expand.grid(alpha = 1, lambda = 10^seq(1, -4, length.out = 100))
    )
    
    main$lasso_aps1 <- lasso_ps$finalModel %>%
      predict(newx = maina_X_sq, s = lasso_ps$finalModel$lambdaOpt, type = "response")
    main$lasso_aps0 <- 1 - main$lasso_aps1
    
    # (2) Ridge regression (alpha = 0)
    cat(" ps: ridge", "\n")
    ridge_ps <- train(
      a_form_sq,
      data = aux,
      method = "glmnet",
      trControl = aControl,
      tuneGrid = expand.grid(alpha = 0, lambda = 10^seq(1, -4, length.out = 100))
    )
    
    main$ridge_aps1 <- ridge_ps$finalModel %>%
      predict(newx = maina_X_sq, s = ridge_ps$finalModel$lambdaOpt, type = "response")
    main$ridge_aps0 <- 1 - main$ridge_aps1
    
    # (3) Elastic net (alpha = 0.5)
    cat(" ps: elastic_net", "\n")
    elastic_net_ps <- train(
      a_form_sq,
      data = aux,
      method = "glmnet",
      trControl = aControl,
      tuneGrid = expand.grid(alpha = 0.5, lambda = 10^seq(1, -4, length.out = 100))
    )
    
    main$elastic_net_aps1 <- elastic_net_ps$finalModel %>%
      predict(newx = maina_X_sq, s = elastic_net_ps$finalModel$lambdaOpt, type = "response")
    main$elastic_net_aps0 <- 1 - main$elastic_net_aps1
    
    
    ###############################
    # Outcome models
    ###############################
    # Folds are drawn on the outcome among units with mediator == 1.
    # returnTrain = TRUE is required because trainControl(index = ) expects the
    # TRAINING rows of each resample; createFolds() otherwise returns the
    # held-out rows, which would tune every model on a single fold.
    m1_folds <- createFolds(aux$outcome[aux$mediator==1], k = 5, returnTrain = TRUE)
    yControl <- trainControl(method = "cv",
                             index = m1_folds,
                             summaryFunction = mnLogLoss,
                             classProbs = TRUE)
    
    
    # Outcome labels for the mediator == 1 subsample; both classifiers below
    # are trained against this vector.
    outcome_m1 <- aux$outcome[aux$mediator == 1]
    
    # Random forest outcome model
    cat(" outcome: rf", "\n")
    rf_y <- train(
      aux_X,
      outcome_m1,
      method = "ranger",
      trControl = yControl,
      tuneLength = 3
    )
    
    # Predicted probability of outcome "yes" on the main0 / main1 versions of
    # the held-out fold (per the original comments: treatment = no / yes).
    main$rf_y0 <- predict(rf_y$finalModel, data = main0, type = "response")$predictions[, "yes"]
    main$rf_y1 <- predict(rf_y$finalModel, data = main1, type = "response")$predictions[, "yes"]
    
    
    # Support vector machine outcome model (radial kernel, 2 x 4 tuning grid)
    cat(" outcome: svm", "\n")
    svm_y <- train(
      aux_X,
      outcome_m1,
      method = "svmRadial",
      trControl = yControl,
      tuneGrid = expand.grid(sigma = c(0.005, 0.01), C = c(0.5, 1, 5, 10))
    )
    
    main$svm_y0 <- svm_y$finalModel %>%
      kernlab::predict(main0_X, type = "probabilities") %>%
      .[, "yes"]
    main$svm_y1 <- svm_y$finalModel %>%
      kernlab::predict(main1_X, type = "probabilities") %>%
      .[, "yes"]
    
    # Penalised outcome regressions (glmnet) on the squared-term design matrix.
    #
    # The lambda grid was written as seq(0.0001, 50, length = 0.1). seq() rounds
    # a fractional length.out up to 1, so that "grid" was the single value
    # 0.0001 and no tuning took place. The grid below is the same log-spaced
    # grid the propensity-score glmnet models in this script use.
    # The lasso fit also searched alpha = 0:1 (ridge and lasso together); it is
    # now pinned to alpha = 1 so the three fits are lasso, ridge and elastic
    # net respectively.
    
    # (1) Lasso Regression
    cat(" outcome: lasso", "\n")
    lasso_y <- train(
      aux_X_outcome_sq,
      aux$outcome[aux$mediator==1], 
      method = "glmnet",
      trControl = yControl, 
      tuneGrid = expand.grid(
        alpha = 1, 
        lambda = 10^seq(1, -4, length.out = 100))
    )
    
    main$lasso_y0 <- predict(lasso_y$finalModel, newx = main0_X_outcome_sq, s = lasso_y$finalModel$lambdaOpt, type = "response") 
    main$lasso_y1 <- predict(lasso_y$finalModel, newx = main1_X_outcome_sq, s = lasso_y$finalModel$lambdaOpt, type = "response") 
    
    # (2) Ridge Regression
    cat(" outcome: ridge", "\n")
    ridge_y <- train(
      aux_X_outcome_sq,
      aux$outcome[aux$mediator==1], 
      method = "glmnet",
      trControl = yControl, 
      tuneGrid = expand.grid(
        alpha = 0, 
        lambda = 10^seq(1, -4, length.out = 100))
    )
    
    main$ridge_y0 <- predict(ridge_y$finalModel, newx = main0_X_outcome_sq,
                             s = ridge_y$finalModel$lambdaOpt, type = "response") 
    
    main$ridge_y1 <- predict(ridge_y$finalModel, newx = main1_X_outcome_sq,
                             s = ridge_y$finalModel$lambdaOpt, type = "response")
    
    # (3) Elastic Net
    cat(" outcome: elastic_net", "\n")
    elastic_net_y <- train(
      aux_X_outcome_sq,
      aux$outcome[aux$mediator==1], 
      method = "glmnet",
      trControl = yControl,
      tuneGrid = expand.grid(
        alpha = 0.5,
        lambda = 10^seq(1, -4, length.out = 100))
    )
    
    main$elastic_net_y0 <- predict(elastic_net_y$finalModel, newx = main0_X_outcome_sq,
                                   s = elastic_net_y$finalModel$lambdaOpt, type = "response")
    main$elastic_net_y1 <- predict(elastic_net_y$finalModel, newx = main1_X_outcome_sq,
                                   s = elastic_net_y$finalModel$lambdaOpt, type = "response")
    
    # Gradient Boosting Machine outcome model, fitted directly with gbm() on
    # the mediator == 1 rows (mediator columns dropped from the training frame).
    newd <- aux %>%
      filter(mediator == 1) %>%
      dplyr::select(-c(mediator_factor, mediator))
    gbm_y <- gbm(
      y_form,
      data = newd,
      distribution = "bernoulli",
      n.trees = 300,
      interaction.depth = 5,
      cv.folds = 5
    )
    # Number of trees chosen by gbm's internal cross-validation.
    best_iter_m <- gbm.perf(gbm_y, method = "cv", plot.it = FALSE)
    
    main$gbm_y1 <- gbm_y %>%
      predict(newdata = main1gbmoutcome, n.trees = best_iter_m, type = "response")
    main$gbm_y0 <- gbm_y %>%
      predict(newdata = main0gbmoutcome, n.trees = best_iter_m, type = "response")
    
    
    # Keep this fold's held-out rows together with all predictions added above.
    predicted_list[[k]] <- main
    
  }
  
  
  # Stack the per-fold predictions and keep the mediator == 1 rows.
  # predicted_df keeps the un-jittered copy.
  d2 <- predicted_df <- reduce(predicted_list, bind_rows) %>% filter(mediator == 1)
  
  # Nudge repeated glmnet predictions by 0.001 so that later occurrences of a
  # value are no longer exact ties.
  # Each column is now indexed by ITS OWN duplicate mask. Previously the ridge
  # and elastic-net rows used the lasso mask on the left-hand side and their own
  # mask on the right, so the two sides could select different rows (and
  # different numbers of rows).
  for (tie_col in c("lasso_y0", "lasso_y1",
                    "ridge_y0", "ridge_y1",
                    "elastic_net_y0", "elastic_net_y1")) {
    is_dup <- duplicated(d2[[tie_col]])
    d2[[tie_col]][is_dup] <- d2[[tie_col]][is_dup] + 0.001
  }
  
  
  #################################
  ### Two-step estimator of \nu_y(X)
  ##################################
  
  y_methods <- c("rf", "svm", "gbm")
  
  ######################################
  ######## (1)  Random Forest   ########
  ######################################
  
  # For every outcome learner in y_methods, regress its predicted outcomes
  # (<method>_y0 / <method>_y1) on x_rhs + polity_b with a random forest, using
  # K-fold cross-fitting over d2, and store the out-of-fold predictions as
  # <method>_rf_y0 / <method>_rf_y1. d2 is rebuilt (rows re-ordered by fold)
  # after each outcome learner.
  for (ymethods in y_methods) {
    y0_name <- paste0(ymethods, "_y0")
    y1_name <- paste0(ymethods, "_y1")
    nu_form_y0 <- as.formula(expr(!!sym(y0_name) ~ !!x_rhs + polity_b)) # outcome already predicted at M = 1 under A = a and A = a' for all units
    nu_form_y1 <- as.formula(expr(!!sym(y1_name) ~ !!x_rhs + polity_b))
    
    K <- 5
    whole_folds <- createFolds(d2$outcome, k = K) 
    predicted_list_nu <- vector(mode = "list", K)
    
    for (k in seq_len(K)) {
      
      cat("fold ", k, "\n")
      aux <- d2[-whole_folds[[k]], ]
      main <- d2[whole_folds[[k]], ]
      
      ## Data matrices for train()
      aux_X_y0 <- model.matrix(nu_form_y0, data = aux)[, -1]
      aux_X_y1 <- model.matrix(nu_form_y1, data = aux)[, -1]
      
      # trainControl(index = ) needs the TRAINING rows of each resample, hence
      # returnTrain = TRUE (createFolds() returns held-out rows by default).
      myFolds <- createFolds(aux$outcome, k = 5, returnTrain = TRUE)
      
      yControl <- trainControl(method = "cv",
                               index = myFolds,
                               summaryFunction = defaultSummary)
      
      cat(" nu:", ymethods, "_rf_y0", "\n")
      # Column lookup by name replaces the former eval(parse_expr(paste0(...))).
      y0_vec <- aux[[y0_name]]
      
      rf_nuy_y0 <- train(aux_X_y0, y0_vec, method = "ranger",
                         trControl = yControl, tuneLength = 3)
      
      # Naming is <outcome learner>_<nu learner>_y0, e.g. rf_rf_y0.
      main[[paste0(ymethods, "_rf_y0")]] <- predict(rf_nuy_y0$finalModel, data = main)$predictions
      
      
      cat(" nu:", ymethods, "_rf_y1", "\n")
      y1_vec <- aux[[y1_name]]
      
      # Uses the y1 design matrix (the y0 matrix was passed here before; the two
      # share the same right-hand side).
      rf_nuy_y1 <- train(aux_X_y1, y1_vec, method = "ranger",
                         trControl = yControl, tuneLength = 3)
      
      main[[paste0(ymethods, "_rf_y1")]] <- predict(rf_nuy_y1$finalModel, data = main)$predictions
      
      
      predicted_list_nu[[k]] <- main
      
    }
    
    d2 <- reduce(predicted_list_nu, bind_rows) 
  }
  
  ######################################
  ##### Support Vector Machine #########
  ######################################
  
  # Same cross-fitting scheme as the random-forest nu models, with a radial SVM
  # regression as the second-stage learner. Out-of-fold predictions are stored
  # as <method>_svm_y0 / <method>_svm_y1.
  for (ymethods in y_methods) {
    y0_name <- paste0(ymethods, "_y0")
    y1_name <- paste0(ymethods, "_y1")
    nu_form_y0 <- as.formula(expr(!!sym(y0_name) ~ !!x_rhs + polity_b)) 
    nu_form_y1 <- as.formula(expr(!!sym(y1_name) ~ !!x_rhs + polity_b))
    
    K <- 5
    whole_folds <- createFolds(d2$outcome, k = K) 
    predicted_list_nu <- vector(mode = "list", K)
    
    # Tuning grid shared by the y0 and y1 fits.
    svm_nu_grid <- expand.grid(sigma = c(0.005, 0.01), C = c(0.5, 1, 5, 10))
    
    for (k in seq_len(K)) {
      
      cat("fold ", k, "\n")
      aux <- d2[-whole_folds[[k]], ]
      main <- d2[whole_folds[[k]], ]
      
      # Held-out design matrix used for prediction by both fits.
      main_X <- model.matrix(nu_form_y0, data = main)[, -1]
      
      ## Data matrices for train()
      aux_X_y0 <- model.matrix(nu_form_y0, data = aux)[, -1]
      aux_X_y1 <- model.matrix(nu_form_y1, data = aux)[, -1]
      
      # trainControl(index = ) needs the TRAINING rows of each resample, hence
      # returnTrain = TRUE (createFolds() returns held-out rows by default).
      myFolds <- createFolds(aux$outcome, k = 5, returnTrain = TRUE)
      
      yControl <- trainControl(method = "cv",
                               index = myFolds,
                               summaryFunction = defaultSummary)
      
      
      cat(" nu:", ymethods, "_svm_y0", "\n")
      # Column lookup by name replaces the former eval(parse_expr(paste0(...))).
      y0_vec <- aux[[y0_name]]
      svm_nu_y0 <- train(
        aux_X_y0,
        y0_vec,
        method = "svmRadial", 
        trControl = yControl,
        tuneGrid = svm_nu_grid)
      
      main[[paste0(ymethods, "_svm_y0")]] <- kernlab::predict(svm_nu_y0$finalModel, main_X)
      
      cat(" nu:", ymethods, "_svm_y1", "\n")
      y1_vec <- aux[[y1_name]]
      # Uses the y1 design matrix (the y0 matrix was passed here before; the two
      # share the same right-hand side).
      svm_nu_y1 <- train(
        aux_X_y1,
        y1_vec,
        method = "svmRadial",
        trControl = yControl,
        tuneGrid = svm_nu_grid)
      
      main[[paste0(ymethods, "_svm_y1")]] <- kernlab::predict(svm_nu_y1$finalModel, main_X)
      
      
      predicted_list_nu[[k]] <- main
      
    }
    
    d2 <- reduce(predicted_list_nu, bind_rows) 
  }
  
  
  
  ######################################
  ################## GBM ###############
  ######################################
  
  # Same cross-fitting scheme as the other nu models, with a gaussian gbm as the
  # second-stage learner. gbm() is called through its formula interface and
  # tunes the number of trees with its own cv.folds, so the caret design
  # matrices and trainControl object this loop used to build were never used
  # and have been removed. Out-of-fold predictions are stored as
  # <method>_gbm_y0 / <method>_gbm_y1.
  for (ymethods in y_methods) {
    nu_form_y0 <- as.formula(expr(!!sym(paste0(ymethods, "_y0")) ~ !!x_rhs + polity_b)) 
    nu_form_y1 <- as.formula(expr(!!sym(paste0(ymethods, "_y1")) ~ !!x_rhs + polity_b))
    
    K <- 5
    whole_folds <- createFolds(d2$outcome, k = K) 
    predicted_list_nu <- vector(mode = "list", K)
    
    for (k in seq_len(K)) {
      
      cat("fold ", k, "\n")
      aux <- d2[-whole_folds[[k]], ]
      main <- d2[whole_folds[[k]], ]
      
      # The result of this draw is not used. The call is kept only so the
      # random-number stream -- and with it every later fold assignment and
      # gbm fit -- is the same as in earlier runs of this script.
      invisible(createFolds(aux$outcome, k = 5))
      
      cat(" nu:", ymethods, "_gbm_y0", "\n")
      gbm_nu_y0 <- gbm(nu_form_y0, data = aux, interaction.depth = 5, cv.folds = 5,
                       n.trees = 300, distribution = "gaussian")
      best_iter_m <- gbm.perf(gbm_nu_y0, method = "cv", plot.it = FALSE)
      
      # Column assignment by name replaces the former eval(parse_expr(paste0(...))).
      main[[paste0(ymethods, "_gbm_y0")]] <- predict(gbm_nu_y0, newdata = main, n.trees = best_iter_m)
      
      cat(" nu:", ymethods, "_gbm_y1", "\n")
      gbm_nu_y1 <- gbm(nu_form_y1, data = aux, interaction.depth = 5, cv.folds = 5,
                       n.trees = 300, distribution = "gaussian")
      best_iter_m <- gbm.perf(gbm_nu_y1, method = "cv", plot.it = FALSE)
      
      main[[paste0(ymethods, "_gbm_y1")]] <- predict(gbm_nu_y1, newdata = main, n.trees = best_iter_m)
      
      
      predicted_list_nu[[k]] <- main
      
    }
    
    d2 <- reduce(predicted_list_nu, bind_rows) 
  }
  
  
  
  # Copy of d2 as it stands after the nu-model loops. d3 is assigned again
  # further down, once outcome_num has been added.
  d3 <- d2
  
  #   # ######################################
  #   # ############ Lasso  ########
  #   # ######################################
  #  
  # for (ymethods in y_methods) {
  #   nu_form_y0 <- as.formula(expr(!!sym(paste0(ymethods, "_y0")) ~ !!x_rhs + polity_b))
  #   nu_form_y1 <- as.formula(expr(!!sym(paste0(ymethods, "_y1")) ~ !!x_rhs + polity_b))
  # 
  # k <- 1
  #  K <- 5
  # whole_folds <- createFolds(d2$outcome, k = K)
  #  predicted_list_nu <- vector(mode = "list", K)
  # 
  # for (k in 1:K) {
  #   
  #  cat("fold ", k, "\n")
  #    aux <- d2[-whole_folds[[k]], ]
  #   main <- d2[whole_folds[[k]], ]
  # 
  #     ## Data matrices for train() and predict()
  #     # Auxiliary sample models
  #     aux_X_y0 <-  model.matrix(nu_form_y0, data = aux)[, -1] #outcome model only
  #     aux_X_y1 <-  model.matrix(nu_form_y1, data = aux)[, -1] #outcome model only
  # 
  #     myFolds <- createFolds(aux$outcome, k = 5)
  # 
  #     yControl <- trainControl(method = "cv",
  #                              index = myFolds,
  #                              summaryFunction = defaultSummary,
  #                              savePredictions = "all")
  # 
  #     metric = 'RMSE'
  #     
  #     cat(" outcome: lasso_nu_y0", "\n")
  # 
  #     eval(parse_expr(paste0("y0_vec <- aux$", ymethods, "_y0")))
  # 
  #     lasso_nu_y0 <- train(aux_X_y0, y0_vec, method = "glmnet",
  #       trControl = yControl,
  #       tuneGrid = expand.grid(
  #         alpha = 1,
  #         lambda = 10^seq(2, -4, length=100))
  #     )
  # 
  #     main$lasso_nu_y0 <- predict(lasso_nu_y0$finalModel, newx = main0_X_outcome_sq, s = lasso_nu_y0$finalModel$lambdaOpt, type = "response")
  # 
  #     eval(parse_expr(paste0("main$", ymethods, "_lasso_y0 <- predict(lasso_nu_y0$finalModel, newx = main0_X_outcome_sq, s = lasso_nu_y0$finalModel$lambdaOpt)")))
  
  
  #######
  
  # 
  #     main$lasso_nu_y0 <- predict(lasso_ny_y0$finalModel, data = main)$predictions
  # 
  #     cat(" outcome: rf_nu_y1", "\n")
  #     lasso_ny_y1 <- train(aux_X_y0, aux$lasso_y1, method = "ranger",
  #                       trControl = yControl, tuneLength = 3)
  # 
  #     main$lasso_nu_y1 <- predict(lasso_ny_y1$finalModel, data = main)$predictions
  # 
  #     predicted_list_nu[[k]] <- main
  # 
  # 
  #   }
  # 
  #       d2 <- reduce(predicted_list_nu, bind_rows)
  # }
  
  
  
  # Numeric 0/1 copy of the outcome for the residual terms below.
  d3 <- predicted_df_nu <- d2 %>% mutate(outcome_num = if_else(outcome == "yes", 1, 0))
  
  # ### Triply robust estimator signals
  
  ### With nu combos
  print(y_methods)
  nu_methods <- c("rf", "svm", "gbm")
  a_methods <- m_methods <- c("rf", "svm", "lasso", "elastic_net", "ridge", "gbm")
  y_methods <-  c("rf", "svm", "gbm")
  
  # For every combination of treatment-propensity, mediator-propensity, outcome
  # and nu learner, add three columns named <a>_<m>_<y>_<nu>_{y0, y1, TE}:
  #   signal_a = nu_a
  #            + 1(A = a) / trim(aps_a)            * (mu_a - nu_a)
  #            + 1(A = a) / (trim(aps_a) * mps_a)  * (Y - mu_a)
  #   TE       = signal_1 - signal_0
  #
  # The indicator in the two correction terms now tests the TREATMENT. It
  # previously tested the outcome (outcome == "no" / "yes"), which weighted
  # units by their observed outcome instead of their treatment arm.
  # NOTE(review): assumes d2 carries treatment_factor with levels "no"/"yes",
  # as used when the treatment propensity models were fitted -- confirm.
  for (a_psmod in a_methods) {
    for (m_psmod in m_methods) {
      for (y_mod in y_methods) {
        for (nu_ymod in nu_methods) {
          
          combo <- str_c(a_psmod, "_", m_psmod, "_", y_mod, "_", nu_ymod)
          nu0 <- str_c(y_mod, "_", nu_ymod, "_y0")
          nu1 <- str_c(y_mod, "_", nu_ymod, "_y1")
          mu0 <- str_c(y_mod, "_y0")
          mu1 <- str_c(y_mod, "_y1")
          aps0 <- str_c(a_psmod, "_aps0")
          aps1 <- str_c(a_psmod, "_aps1")
          mps0 <- str_c(m_psmod, "_mps0")
          mps1 <- str_c(m_psmod, "_mps1")
          
          predicted_df_nu <- predicted_df_nu %>% 
            mutate(
              !!sym(str_c(combo, "_y0")) := .data[[nu0]] +
                as.double(treatment_factor == "no") / trim(.data[[aps0]]) *
                  (.data[[mu0]] - .data[[nu0]]) +
                as.double(treatment_factor == "no") / (trim(.data[[aps0]]) * .data[[mps0]]) *
                  (outcome_num - .data[[mu0]]),
              !!sym(str_c(combo, "_y1")) := .data[[nu1]] +
                as.double(treatment_factor == "yes") / trim(.data[[aps1]]) *
                  (.data[[mu1]] - .data[[nu1]]) +
                as.double(treatment_factor == "yes") / (trim(.data[[aps1]]) * .data[[mps1]]) *
                  (outcome_num - .data[[mu1]]),
              !!sym(str_c(combo, "_TE")) :=
                .data[[str_c(combo, "_y1")]] - .data[[str_c(combo, "_y0")]]
            )
        }
      }
    }
  }
  
  
  
  # Collects one small tibble per (model combination, estimand type).
  holder_CATE <- vector(mode = "list")
  
  # Unweighted design over the signal columns; svydesign()'s warning about the
  # missing weights is deliberately silenced.
  design <- suppressWarnings(svydesign(ids = ~ 1, data = predicted_df_nu))
  
  # Prediction grid for the conditional estimates (polity_b = 0 and 1). It does
  # not depend on the model combination, so it is built once.
  newdata <- data.frame(polity_b = c(0, 1))
  newX <- model.matrix(~ polity_b, data = newdata) %>% as.data.frame() %>%
    mutate(polity_b = factor(polity_b))
  
  for (a_psmod in a_methods) {
    for (m_psmod in m_methods) {
      for (y_mod in y_methods) {
        for (nu_ymod in nu_methods) {
          
          combo <- paste0(a_psmod, "_", m_psmod, "_", y_mod, "_", nu_ymod)
          
          name0a <- sym(paste0(combo, "_y0"))
          name1a <- sym(paste0(combo, "_y1"))
          nameATEa <- sym(paste0(combo, "_TE"))
          
          # Fitted models are still published under <combo>_mod0 / _mod1 /
          # _modTE, as before.
          mod0a <- paste0(combo, "_mod0")
          mod1a <- paste0(combo, "_mod1")
          modATEa <- paste0(combo, "_modTE")
          
          # Overall effect: intercept-only regressions of each signal.
          assign(mod0a, eval(expr(svyglm(!!name0a ~ 1, design = design))))
          assign(mod1a, eval(expr(svyglm(!!name1a ~ 1, design = design))))
          fit_te <- eval(expr(svyglm(!!nameATEa ~ 1, design = design)))
          assign(modATEa, fit_te)
          
          df <- tibble("estimand" = "overall_ate",
                       "model" = combo,
                       "est" = as.numeric(coef(fit_te)),
                       "se" = as.numeric(sqrt(diag(vcov(fit_te)))))
          # Explicit append; the former holder_CATE[[""]] <- df relied on an
          # empty name never matching an existing element.
          holder_CATE[[length(holder_CATE) + 1L]] <- df
          
          
          # CATE by dyad type: the same objects are refitted with polity_b as
          # a covariate (this overwrites the intercept-only fits).
          assign(mod0a, eval(expr(svyglm(!!name0a ~ polity_b, design = design))))
          assign(mod1a, eval(expr(svyglm(!!name1a ~ polity_b, design = design))))
          fit_te <- eval(expr(svyglm(!!nameATEa ~ polity_b, design = design)))
          assign(modATEa, fit_te)
          
          ate <- predict(fit_te, newX) %>% as.data.frame() %>%
            `rownames<-`(c("b_autocracy_ate", "b_democracy_ate"))
          
          df3 <- ate %>% rownames_to_column() %>%
            transmute(estimand = rowname, model = combo, est = link, se = SE)
          
          holder_CATE[[length(holder_CATE) + 1L]] <- df3
        }
      }
    }
  }
  
  # as_tibble() replaces the deprecated as.tibble().
  output <- as_tibble(reduce(holder_CATE, bind_rows))
  
  
  # # weighting estimating equations fit to adjusted imputations
  # design <- svydesign(ids = ~ 1, weights = ~ weight, data = main_df)
  # mod0 <- svyglm(imp_y0_adj ~ origin, design = design)
  # mod1 <- svyglm(imp_y1_adj ~ origin, design = design)
  # 
  # # save output in holder[[i]]
  # holder[[i]] <- list(mod0, mod1) %>%
  #   set_names(0:1) %>%
  #   enframe(name = "ba", value = "model") %>%
  #   mutate(est = map(model, coef),
  #          vcov = map(model, vcov))
  
  
  # Keep this iteration's table of estimates in slot i of out (i is the counter
  # of the enclosing loop, whose header lies outside this excerpt). The slots
  # are pooled after the loop.
  out[[i]] <- output
}


# Pool the per-iteration estimates for every (estimand, model) pair:
# total variance = mean squared SE + (1 + 1/I) * variance of the estimates.
# NOTE(review): I is not defined in this excerpt; presumably the number of
# iterations stored in out -- confirm.
# est is overwritten last so that between_var is computed from the raw values.
out_df <- bind_rows(out) %>%
  group_by(estimand, model) %>%
  summarise(
    within_var = mean(se^2),
    between_var = var(est),
    total_var = within_var + between_var * (1 + 1 / I),
    est = mean(est),
    se = sqrt(total_var)
  ) %>%
  ungroup()

# "elastic_net" contains the separator itself, so collapse it to one token
# before the model label is split on "_".
out_df$model <- str_replace_all(out_df$model, "elastic_net", "elasticnet")
out_df2 <- separate(
  out_df,
  col = model,
  into = c("psmod", "m_psmod", "y_mod", "nu_ymod"),
  sep = "_"
)

# Learners left out of the ATE plot. The longer list is assigned first and
# immediately replaced, so only ridge is excluded.
no_mods_vec <- c('lasso', 'ridge', 'elasticnet', 'svm')
no_mods_vec <- c('ridge')

#save.image(file = "multiply_imputation.RData")

### ATE PLOT
# Point estimates with 95% intervals for a hand-picked set of model
# combinations in which the treatment and mediator propensity learners match.
#
# Changes in this block:
#  * learners are excluded with %in%, which stays correct if no_mods_vec ever
#    holds more than one name (`!=` would recycle it element-wise);
#  * the stray empty argument in scale_x_discrete() is gone;
#  * the theme() call that preceded theme_minimal() is dropped: adding a
#    complete theme afterwards discarded every one of its settings, so the
#    rendered figure is unchanged;
#  * ggsave() is given the plot explicitly instead of relying on last_plot().
# NOTE(review): the slice() positions depend on the row order of out_df2 and on
# the filters above; positions beyond the number of remaining rows are
# silently ignored. Confirm they still select the intended combinations.
ate_out <- out_df2 %>%
  filter(estimand == "overall_ate") %>%
  filter(psmod == m_psmod, !psmod %in% no_mods_vec) %>%
  dplyr::slice(c(1, 18, 19, 31, 37, 47)) %>%
  mutate(mod = factor(psmod, levels = c("rf", "lasso", "elasticnet", "ridge", "gbm", "svm"),
                      labels = c("Random Forest", "Lasso", "Elastic Net", "Ridge", "GBM", "SVM"))) %>%
  ggplot(aes(x = mod, y = est)) +
  geom_point(size = 3.35, color = "grey50") +
  geom_errorbar(aes(ymin = est - 1.96 * se, ymax = est + 1.96 * se,
                    width = 0.5), show.legend = FALSE, color = "grey50") +
  scale_x_discrete(name = "\nML Algorithm Used for\nAll Nuisance Functions") +
  scale_y_continuous(name = "CDE Estimate (Probability Increase)\nand 95% CIs", limits = c(0, 10)) +
  theme_minimal(base_size = 15)
ggsave(path = "Figures", filename = "ate_equal.png", plot = ate_out, width = 6, height = 5)

## CATE PLOT
# Conditional estimates by State B regime type, one facet per learner, for a
# hand-picked set of rows in which the treatment and mediator propensity
# learners match.
cate_out <- out_df2 %>%
  filter(estimand != "overall_ate", psmod == m_psmod) %>%
  arrange(psmod, m_psmod, y_mod, nu_ymod, estimand) %>%
  dplyr::slice(c(7, 8, 24, 25, 43, 44, 63, 64, 97, 98)) %>%
  mutate(
    mod = factor(
      psmod,
      levels = c("rf", "lasso", "elasticnet", "svm", "gbm"),
      labels = c("Random Forest", "Lasso", "Elastic Net", "SVM", "GBM")
    ),
    Estimand = factor(
      estimand,
      levels = c("b_democracy_ate", "b_autocracy_ate"),
      labels = c("Democracy", "Autocracy")
    )
  ) %>%
  ggplot(aes(x = Estimand, y = est)) +
  geom_point(aes(color = Estimand), size = 3.35) +
  geom_errorbar(
    aes(ymin = est - 1.96 * se, ymax = est + 1.96 * se, width = 0.5, color = Estimand),
    show.legend = FALSE
  ) +
  facet_wrap(~mod) +
  labs(color = "State B Regime \nStatus") +
  scale_x_discrete(name = "") +
  scale_y_continuous(name = "CDE Estimate (Probability Increase)\nand 95% CIs", limits = c(0, 25)) +
  scale_color_manual(values = alpha(c("dodgerblue", "seagreen3"), 0.84)) +
  scale_fill_manual(values = alpha(c("dodgerblue", "seagreen3"), 0.84)) +
  theme_minimal(base_size = 15) +
  # Single theme() call; the legend is placed inside the panel area.
  theme(
    plot.title = element_text(hjust = 0.5),
    plot.caption = element_text(color = "grey30"),
    axis.title.y = element_text(margin = margin(t = 0, r = 10, b = 0, l = 0)),
    axis.text.x = element_blank(),
    legend.position = c(0.85, 0.35),
    legend.direction = "vertical",
    legend.title = element_text(size = 13),
    legend.text = element_text(size = 11)
  )
ggsave(path = "Figures", filename = "cate_equal.png", plot = cate_out, width = 6, height = 5)




####### ####### ####### ####### ####### ####### ####### #####
######### Appendix extra ML ALGORITHMS #######
####### ####### ####### ####### ####### ####### #######
library(kableExtra)

# Appendix table: a random sample of 20 overall estimates (est < 11), labelled
# by the learner used for each nuisance function.
#
# Fixes in this block:
#  * the nu-model column is now created under the name the final select()
#    asks for ("Predicted\nOutcome\nModel"); it used to be created as
#    "Predicted\nModel", so select() failed on a missing column;
#  * type= and file= are arguments of print.xtable(), not xtable(), so the
#    table is now printed through print() and actually written to the file.
overall_levels <- c("rf", "lasso", "elasticnet", "svm", "gbm", "ridge")
overall_labels <- c("RF", "Lasso", "E-Net", "SVM", "GBM", "Ridge")

overall_appendix <- filter(out_df2, estimand == "overall_ate") %>%
  mutate(
    "Treatment\nModel" = factor(psmod, levels = overall_levels, labels = overall_labels),
    "Mediator\nModel" = factor(m_psmod, levels = overall_levels, labels = overall_labels),
    "Outcome\nModel" = factor(y_mod, levels = overall_levels, labels = overall_labels),
    "Predicted\nOutcome\nModel" = factor(nu_ymod, levels = overall_levels, labels = overall_labels)
  ) %>%
  dplyr::select(-c(within_var, between_var, total_var, estimand, psmod, m_psmod, y_mod, nu_ymod)) %>%
  filter(est < 11) %>%
  sample_n(20) %>%
  rename(Estimate = est, "Standard\nError" = se) %>%
  dplyr::select("Treatment\nModel", "Mediator\nModel", "Outcome\nModel", "Predicted\nOutcome\nModel",
                Estimate, "Standard\nError")
print(xtable(overall_appendix), type = "latex", file = "filename2.tex")


# Appendix table of conditional estimates: 20 randomly chosen model
# combinations, each shown for both regime types of State B.
#
# Fixes in this block:
#  * sep = "_" was inside c(...) in separate(into = ), which added a fifth
#    (empty) target column instead of setting the separator;
#  * type= and file= are arguments of print.xtable(), not xtable(), so the
#    table is now printed through print() and actually written to the file;
#  * `<-` is used for assignment throughout.
cond_levels <- c("rf", "lasso", "elasticnet", "svm", "gbm", "ridge")
cond_labels <- c("RF", "Lasso", "ENet", "SVM", "GBM", "Ridge")
cond_model_cols <- c("Treatment Model", "Mediator Model", "Outcome Model", "Predicted Model")

conditional_appendixa <- filter(out_df2, estimand != "overall_ate") %>%
  mutate(
    "Treatment Model" = factor(psmod, levels = cond_levels, labels = cond_labels),
    "Mediator Model" = factor(m_psmod, levels = cond_levels, labels = cond_labels),
    "Outcome Model" = factor(y_mod, levels = cond_levels, labels = cond_labels),
    "Predicted Model" = factor(nu_ymod, levels = cond_levels, labels = cond_labels)
  ) %>%
  dplyr::select(-c(within_var, between_var, total_var, psmod, m_psmod, y_mod, nu_ymod)) %>%
  # One key per model combination so that both regime rows can be matched.
  unite(new, "Treatment Model", "Mediator Model", "Outcome Model", "Predicted Model", sep = "_") 

# Sample the combinations from the democracy rows with est < 17.2, then pull
# both regime rows for every sampled combination.
selection <- conditional_appendixa %>%
  filter(estimand != "b_autocracy_ate") %>%
  filter(est < 17.2) %>%
  sample_n(20)
selectionvec <- selection$new
conditional_appendixb <- conditional_appendixa %>%
  filter(new %in% selectionvec) %>%
  arrange(new, estimand) %>%
  separate(new, into = cond_model_cols, sep = "_")

conditional_appendixc <- conditional_appendixb %>% rename(Estimate = est, SE = se) %>%
  mutate("Regime B Status" = factor(estimand, levels = c("b_autocracy_ate", "b_democracy_ate"), labels = c("Autocracy", "Democracy"))) %>%
  dplyr::select("Regime B Status", "Treatment Model", "Mediator Model" , "Outcome Model", "Predicted Model",
                Estimate, SE) 
print(xtable(conditional_appendixc), type = "latex", file = "filename4.tex")
